from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class CDTTrainConfig:
    """Default hyper-parameters shared by every CDT training run.

    The task-specific configs in this module subclass this dataclass and
    override only the fields that differ. Fields whose default is ``None``
    are annotated ``Optional[...]`` so the declared type matches the value a
    fresh instance actually holds.
    """

    # wandb params
    project: str = "OSRL-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "CDT"
    suffix: Optional[str] = ""
    logdir: Optional[str] = "logs"
    verbose: bool = True
    target_task: int = 0
    # dataset params
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0
    # model params
    embedding_dim: int = 512
    num_layers: int = 3
    num_heads: int = 8
    action_head_layers: int = 1
    seq_len: int = 20
    prompt_seq_len: int = 10
    episode_len: int = 300
    attention_dropout: float = 0.1
    residual_dropout: float = 0.1
    embedding_dropout: float = 0.1
    time_emb: bool = True
    lora_rank: int = 8
    fpf: bool = False
    # training params
    # alternative default task: "OfflineCarCircle-v0"
    task: str = "OfflineAntVelocityGymnasium-v2"
    context_encoder_path: str = "logs/context_encoder/ContextEncoder_decay_rate1.0-ba5c/ContextEncoder_decay_rate1.0-ba5c"
    data_path: Optional[str] = None
    use_prompt: bool = True
    prompt_prefix: bool = True
    prompt_concat: bool = False
    prompt_dim: int = 16
    use_sa_encoder: bool = False
    train_action_encoder: bool = False
    pretrained_initialize: bool = False
    state_encoder_path: Optional[str] = None
    action_encoder_path: Optional[str] = None
    state_encode_dim: int = 64
    action_encode_dim: int = 32
    # Hidden-layer widths are integer unit counts; annotating them as
    # List[int] makes CLI overrides parse to ints rather than floats.
    # `is_mutable=True` is how pyrallis accepts a list default.
    state_encoder_hidden_sizes: List[int] = field(
        default=[128, 128, 128], is_mutable=True
    )
    action_encoder_hidden_sizes: List[int] = field(
        default=[128, 128, 128], is_mutable=True
    )
    dataset: Optional[str] = None
    learning_rate: float = 1e-4
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 1e-4
    clip_grad: Optional[float] = 0.25
    batch_size: int = 1024
    update_steps: int = 200_000
    lr_warmup_steps: int = 500
    reward_scale: float = 0.1
    cost_scale: float = 1
    num_workers: int = 1
    # evaluation params
    target_returns: Tuple[Tuple[float, ...],
                          ...] = ((450.0, 10), (500.0, 20), (550.0, 50))  # reward, cost
    cost_limit: int = 10
    eval_episodes: int = 20
    eval_every: int = 2500
    # general params
    seed: int = 0
    device: str = "cuda:1"
    threads: int = 6
    # augmentation param
    deg: int = 4
    pf_sample: bool = False
    beta: float = 1.0
    augment_percent: float = 0.2
    # maximum absolute value of reward for the augmented trajs
    max_reward: float = 600.0
    # minimum reward above the PF curve
    min_reward: float = 1.0
    # the max decrease of return between the associated traj
    # and the nearest PF traj
    max_rew_decrease: float = 100.0
    # model mode params
    use_rew: bool = True
    use_cost: bool = True
    cost_transform: bool = True
    cost_prefix: bool = False
    add_cost_feat: bool = False
    mul_cost_feat: bool = False
    cat_cost_feat: bool = False
    loss_cost_weight: float = 0.02
    loss_state_weight: float = 0
    cost_reverse: bool = False
    # pf only mode param
    pf_only: bool = False
    rmin: float = 300
    cost_bins: int = 60
    npb: int = 5
    cost_sample: bool = True
    linear: bool = True  # linear or inverse
    start_sampling: bool = False
    prob: float = 0.2
    stochastic: bool = True
    init_temperature: float = 0.1
    no_entropy: bool = False
    # random augmentation
    random_aug: float = 0
    aug_rmin: float = 400
    aug_rmax: float = 500
    aug_cmin: float = -2
    aug_cmax: float = 25
    cgap: float = 5
    rstd: float = 1
    cstd: float = 0.2


@dataclass
class CDTCarCircleConfig(CDTTrainConfig):
    """CDT config for OfflineCarCircle-v0; all other fields keep base defaults."""

    task: str = "OfflineCarCircle-v0"

@dataclass
class CDTCarCircleOrgConfig(CDTTrainConfig):
    """OfflineCarCircle-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflineCarCircle-v0"
    device: str = "cuda:6"
    # sequence model
    embedding_dim: int = 128
    seq_len: int = 10

@dataclass
class CDTAntRunConfig(CDTTrainConfig):
    """Overrides for the OfflineAntRun-v0 task."""

    # task / hardware
    task: str = "OfflineAntRun-v0"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 200
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (700.0, 10),
        (750.0, 20),
        (800.0, 40),
    )
    # augmentation
    max_rew_decrease: float = 150
    max_reward: float = 1000.0
    deg: int = 3

@dataclass
class CDTAntRunOrgConfig(CDTTrainConfig):
    """OfflineAntRun-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflineAntRun-v0"
    device: str = "cuda:7"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 200
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (700.0, 10),
        (750.0, 20),
        (800.0, 40),
    )
    # augmentation
    max_rew_decrease: float = 150
    max_reward: float = 1000.0
    deg: int = 3


@dataclass
class CDTDroneRunConfig(CDTTrainConfig):
    """Overrides for the OfflineDroneRun-v0 task."""

    # task / hardware
    task: str = "OfflineDroneRun-v0"
    device: str = "cuda:3"
    # sequence model
    episode_len: int = 200
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (400.0, 10),
        (500.0, 20),
        (600.0, 40),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 100
    max_reward: float = 700.0
    deg: int = 1


@dataclass
class CDTDroneCircleConfig(CDTTrainConfig):
    """Overrides for the OfflineDroneCircle-v0 task."""

    # task / hardware
    task: str = "OfflineDroneCircle-v0"
    device: str = "cuda:3"
    # sequence model
    episode_len: int = 300
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (700.0, 10),
        (750.0, 20),
        (800.0, 40),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 100
    max_reward: float = 1000.0
    deg: int = 1


@dataclass
class CDTCarRunConfig(CDTTrainConfig):
    """Overrides for the OfflineCarRun-v0 task."""

    # task / hardware
    task: str = "OfflineCarRun-v0"
    device: str = "cuda:3"
    # sequence model
    episode_len: int = 200
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (575.0, 10),
        (575.0, 20),
        (575.0, 40),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 100
    max_reward: float = 600.0
    deg: int = 0


@dataclass
class CDTAntCircleConfig(CDTTrainConfig):
    """Overrides for the OfflineAntCircle-v0 task."""

    # task / hardware
    task: str = "OfflineAntCircle-v0"
    device: str = "cuda:2"
    # sequence model
    episode_len: int = 500
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 100
    max_reward: float = 500.0
    deg: int = 2


@dataclass
class CDTBallRunConfig(CDTTrainConfig):
    """Overrides for the OfflineBallRun-v0 task."""

    # task / hardware
    task: str = "OfflineBallRun-v0"
    device: str = "cuda:2"
    # sequence model
    episode_len: int = 100
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (500.0, 10),
        (500.0, 20),
        (700.0, 40),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 200
    max_reward: float = 1400.0
    deg: int = 2


@dataclass
class CDTBallCircleConfig(CDTTrainConfig):
    """Overrides for the OfflineBallCircle-v0 task."""

    # task / hardware
    task: str = "OfflineBallCircle-v0"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 200
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (700.0, 10),
        (750.0, 20),
        (800.0, 40),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 200
    max_reward: float = 1000.0
    deg: int = 2


@dataclass
class CDTCarButton1Config(CDTTrainConfig):
    """Overrides for the OfflineCarButton1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarButton1Gymnasium-v0"
    device: str = "cuda:0"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (35.0, 20),
        (35.0, 40),
        (35.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 10
    max_reward: float = 45.0
    deg: int = 0


@dataclass
class CDTCarButton2Config(CDTTrainConfig):
    """Overrides for the OfflineCarButton2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarButton2Gymnasium-v0"
    device: str = "cuda:0"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (40.0, 20),
        (40.0, 40),
        (40.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 10
    max_reward: float = 50.0
    deg: int = 0


@dataclass
class CDTCarCircle1Config(CDTTrainConfig):
    """Overrides for the OfflineCarCircle1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarCircle1Gymnasium-v0"
    device: str = "cuda:0"
    # sequence model
    episode_len: int = 500
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (20.0, 20),
        (22.5, 40),
        (25.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 10
    max_reward: float = 30.0
    deg: int = 1


@dataclass
class CDTCarCircle2Config(CDTTrainConfig):
    """Overrides for the OfflineCarCircle2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarCircle2Gymnasium-v0"
    device: str = "cuda:0"
    # sequence model
    episode_len: int = 500
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (20.0, 20),
        (21.0, 40),
        (22.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 10
    max_reward: float = 30.0
    deg: int = 1


@dataclass
class CDTCarGoal1Config(CDTTrainConfig):
    """Overrides for the OfflineCarGoal1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarGoal1Gymnasium-v0"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (40.0, 20),
        (40.0, 40),
        (40.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 50.0
    deg: int = 1


@dataclass
class CDTCarGoal2Config(CDTTrainConfig):
    """Overrides for the OfflineCarGoal2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarGoal2Gymnasium-v0"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1


@dataclass
class CDTCarPush1Config(CDTTrainConfig):
    """Overrides for the OfflineCarPush1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarPush1Gymnasium-v0"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (15.0, 20),
        (15.0, 40),
        (15.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 20.0
    deg: int = 0


@dataclass
class CDTCarPush2Config(CDTTrainConfig):
    """Overrides for the OfflineCarPush2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineCarPush2Gymnasium-v0"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (12.0, 20),
        (12.0, 40),
        (12.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 3
    max_reward: float = 15.0
    deg: int = 0


@dataclass
class CDTPointButton1Config(CDTTrainConfig):
    """Overrides for the OfflinePointButton1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointButton1Gymnasium-v0"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (40.0, 20),
        (40.0, 40),
        (40.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 45.0
    deg: int = 0
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointButton1Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-894b/sa_encoder_idm_loss_weight0.0-894b_state_AE)
    state_encoder_path: str = "logs/OfflinePointButton1Gymnasium-v0-cost-10/sa_encoder-660a/sa_encoder-660a_state_AE"

@dataclass
class CDTPointButton1OrgConfig(CDTTrainConfig):
    """OfflinePointButton1Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointButton1Gymnasium-v0"
    device: str = "cuda:4"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 1000
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (40.0, 20),
        (40.0, 40),
        (40.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 45.0
    deg: int = 0


@dataclass
class CDTPointButton2Config(CDTTrainConfig):
    """Overrides for the OfflinePointButton2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointButton2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (40.0, 20),
        (40.0, 40),
        (40.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 10
    max_reward: float = 50.0
    deg: int = 0
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointButton2Gymnasium-v0-cost-10/sa_encoder-e7e4/sa_encoder-e7e4_state_AE)
    state_encoder_path: str = "logs/OfflinePointButton2Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-b874/sa_encoder_idm_loss_weight0.0-b874_state_AE"
    

@dataclass
class CDTPointButton2OrgConfig(CDTTrainConfig):
    """OfflinePointButton2Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # model params
    seq_len: int = 10
    episode_len: int = 1000
    # training params
    task: str = "OfflinePointButton2Gymnasium-v0"
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = ((40.0, 20), (40.0, 40), (40.0, 80))
    # augmentation param
    deg: int = 0
    max_reward: float = 50.0
    max_rew_decrease: float = 10
    min_reward: float = 1
    device: str = "cuda:5"
    embedding_dim: int = 128


@dataclass
class CDTPointCircle1Config(CDTTrainConfig):
    """Overrides for the OfflinePointCircle1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointCircle1Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    episode_len: int = 500
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (50.0, 20),
        (52.5, 40),
        (55.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 65.0
    deg: int = 1
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointCircle1Gymnasium-v0-cost-10/sa_encoder-2db5/sa_encoder-2db5_state_AE)
    state_encoder_path: str = "logs/OfflinePointCircle1Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-2c5d/sa_encoder_idm_loss_weight0.0-2c5d_state_AE"

@dataclass
class CDTPointCircle1OrgConfig(CDTTrainConfig):
    """OfflinePointCircle1Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointCircle1Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 500
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (50.0, 20),
        (52.5, 40),
        (55.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 65.0
    deg: int = 1


@dataclass
class CDTPointCircle2Config(CDTTrainConfig):
    """Overrides for the OfflinePointCircle2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointCircle2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 500
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (45.0, 20),
        (47.5, 40),
        (50.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 55.0
    deg: int = 1
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointCircle2Gymnasium-v0-cost-10/sa_encoder-1dbc/sa_encoder-1dbc_state_AE)
    state_encoder_path: str = "logs/OfflinePointCircle2Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-f4f4/sa_encoder_idm_loss_weight0.0-f4f4_state_AE"

@dataclass
class CDTPointCircle2OrgConfig(CDTTrainConfig):
    """OfflinePointCircle2Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointCircle2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 500
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (45.0, 20),
        (47.5, 40),
        (50.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 55.0
    deg: int = 1


@dataclass
class CDTPointGoal1Config(CDTTrainConfig):
    """Overrides for the OfflinePointGoal1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointGoal1Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 0
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointGoal1Gymnasium-v0-cost-10/sa_encoder-82ac/sa_encoder-82ac_state_AE)
    state_encoder_path: str = "logs/OfflinePointGoal1Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-62a4/sa_encoder_idm_loss_weight0.0-62a4_state_AE"

@dataclass
class CDTPointGoal1OrgConfig(CDTTrainConfig):
    """OfflinePointGoal1Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointGoal1Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 1000
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 0

@dataclass
class CDTPointGoal1LargeEmbConfig(CDTTrainConfig):
    """OfflinePointGoal1Gymnasium-v0 with 512-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointGoal1Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    embedding_dim: int = 512
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 0

@dataclass
class CDTPointGoal1WoRewConfig(CDTTrainConfig):
    """OfflinePointGoal1Gymnasium-v0 with use_rew disabled, 512-dim, 4 layers."""

    # task / hardware
    task: str = "OfflinePointGoal1Gymnasium-v0"
    device: str = "cuda:5"
    # sequence model
    num_layers: int = 4
    embedding_dim: int = 512
    episode_len: int = 1000
    seq_len: int = 20
    # model mode
    use_rew: bool = False
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 0


@dataclass
class CDTPointGoal2Config(CDTTrainConfig):
    """Overrides for the OfflinePointGoal2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointGoal2Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointGoal2Gymnasium-v0-cost-10/sa_encoder-668b/sa_encoder-668b_state_AE)
    state_encoder_path: str = "logs/OfflinePointGoal2Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-ea8c/sa_encoder_idm_loss_weight0.0-ea8c_state_AE"

@dataclass
class CDTPointGoal2OrgConfig(CDTTrainConfig):
    """OfflinePointGoal2Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointGoal2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 1000
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1

@dataclass
class CDTPointGoal2LargeEmbConfig(CDTTrainConfig):
    """OfflinePointGoal2Gymnasium-v0 with 512-dim embeddings and 4 layers."""

    # task / hardware
    task: str = "OfflinePointGoal2Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    num_layers: int = 4
    embedding_dim: int = 512
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1

@dataclass
class CDTPointGoal2SmallSeqConfig(CDTTrainConfig):
    """OfflinePointGoal2Gymnasium-v0 with batch size 2048 and 512-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointGoal2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    # NOTE(review): the class name says "SmallSeq" but seq_len is 20, the
    # same as the base default -- confirm which value was intended.
    embedding_dim: int = 512
    episode_len: int = 1000
    seq_len: int = 20
    # optimisation
    batch_size: int = 2048
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1

@dataclass
class CDTPointGoal230seqConfig(CDTTrainConfig):
    """OfflinePointGoal2Gymnasium-v0 with a 30-step context window."""

    # task / hardware
    task: str = "OfflinePointGoal2Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    embedding_dim: int = 512
    episode_len: int = 1000
    seq_len: int = 30
    # optimisation
    batch_size: int = 1024
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1


@dataclass
class CDTPointGoal2WoRewConfig(CDTTrainConfig):
    """OfflinePointGoal2Gymnasium-v0 with use_rew disabled and batch size 2048."""

    # task / hardware
    task: str = "OfflinePointGoal2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    embedding_dim: int = 512
    episode_len: int = 1000
    seq_len: int = 20
    # model mode
    use_rew: bool = False
    # optimisation
    batch_size: int = 2048
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (30.0, 20),
        (30.0, 40),
        (30.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 35.0
    deg: int = 1


@dataclass
class CDTPointPush1Config(CDTTrainConfig):
    """Overrides for the OfflinePointPush1Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointPush1Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (15.0, 20),
        (15.0, 40),
        (15.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 20.0
    deg: int = 0
    # pretrained state encoder checkpoint
    state_encoder_path: str = "logs/OfflinePointPush1Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-afc6/sa_encoder_idm_loss_weight0.0-afc6_state_AE"

@dataclass
class CDTPointPush1OrgConfig(CDTTrainConfig):
    """OfflinePointPush1Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointPush1Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 1000
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (15.0, 20),
        (15.0, 40),
        (15.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 5
    max_reward: float = 20.0
    deg: int = 0


@dataclass
class CDTPointPush2Config(CDTTrainConfig):
    """Overrides for the OfflinePointPush2Gymnasium-v0 task."""

    # task / hardware
    task: str = "OfflinePointPush2Gymnasium-v0"
    device: str = "cuda:6"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (12.0, 20),
        (12.0, 40),
        (12.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 3
    max_reward: float = 15.0
    deg: int = 0
    # pretrained state encoder checkpoint
    # (alternative kept for reference:
    #  logs/OfflinePointPush2Gymnasium-v0-cost-10/sa_encoder-8b2f/sa_encoder-8b2f_state_AE)
    state_encoder_path: str = "logs/OfflinePointPush2Gymnasium-v0-cost-10/sa_encoder_idm_loss_weight0.0-fb2c/sa_encoder_idm_loss_weight0.0-fb2c_state_AE"

@dataclass
class CDTPointPush2OrgConfig(CDTTrainConfig):
    """OfflinePointPush2Gymnasium-v0 with a 10-step context and 128-dim embeddings."""

    # task / hardware
    task: str = "OfflinePointPush2Gymnasium-v0"
    device: str = "cuda:7"
    # sequence model
    embedding_dim: int = 128
    episode_len: int = 1000
    seq_len: int = 10
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (12.0, 20),
        (12.0, 40),
        (12.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 3
    max_reward: float = 15.0
    deg: int = 0


@dataclass
class CDTAntVelocityConfig(CDTTrainConfig):
    """Overrides for the OfflineAntVelocityGymnasium-v1 task."""

    # task / hardware
    task: str = "OfflineAntVelocityGymnasium-v1"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (2800.0, 20),
        (2800.0, 40),
        (2800.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 500
    max_reward: float = 3000.0
    deg: int = 1

@dataclass
class CDTAntVelocityV0Config(CDTTrainConfig):
    """Overrides for the OfflineAntVelocityGymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineAntVelocityGymnasium-v0"
    device: str = "cuda:5"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (2800.0, 20),
        (2800.0, 40),
        (2800.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 500
    max_reward: float = 3000.0
    deg: int = 1

@dataclass
class CDTAntVelocityV2Config(CDTTrainConfig):
    """OfflineAntVelocityGymnasium-v2, reading data from a local .pkl file."""

    # task / data / hardware
    task: str = "OfflineAntVelocityGymnasium-v2"
    data_path: str = "finetune_data/OfflineAntVelocityGymnasium-v2.pkl"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # single evaluation target as (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = ((2600.0, 10),)
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 500
    max_reward: float = 3000.0
    deg: int = 1


@dataclass
class CDTHalfCheetahVelocityConfig(CDTTrainConfig):
    """Overrides for the OfflineHalfCheetahVelocityGymnasium-v1 task."""

    # task / hardware
    task: str = "OfflineHalfCheetahVelocityGymnasium-v1"
    device: str = "cuda:7"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (3000.0, 20),
        (3000.0, 40),
        (3000.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 500
    max_reward: float = 3000.0
    deg: int = 1
    # pretrained state/action encoder checkpoints.
    # Alternative checkpoint directories kept for reference (each holds a
    # *_state_AE and *_action_AE pair under
    # logs/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/):
    #   sa_encoder-e2b5/sa_encoder-e2b5_{state,action}_AE
    #   sa_encoder_idm_loss_weight0.0-9813/sa_encoder_idm_loss_weight0.0-9813_{state,action}_AE
    #   sa_encoder-6ac6/sa_encoder-6ac6_{state,action}_AE
    state_encoder_path: str = "logs/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/sa_encoder-5d57/sa_encoder-5d57_state_AE"
    action_encoder_path: str = "logs/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/sa_encoder-5d57/sa_encoder-5d57_action_AE"

@dataclass
class CDTHalfCheetahVelocityV0Config(CDTTrainConfig):
    """Overrides for the OfflineHalfCheetahVelocityGymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineHalfCheetahVelocityGymnasium-v0"
    device: str = "cuda:5"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (3000.0, 20),
        (3000.0, 40),
        (3000.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 500
    max_reward: float = 3000.0
    deg: int = 1

@dataclass
class CDTHalfCheetahVelocityV2Config(CDTTrainConfig):
    """OfflineHalfCheetahVelocityGymnasium-v2, reading data from a local .pkl file."""

    # task / data / hardware
    task: str = "OfflineHalfCheetahVelocityGymnasium-v2"
    data_path: str = "finetune_data/OfflineHalfCheetahVelocityGymnasium-v2.pkl"
    device: str = "cuda:5"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # single evaluation target as (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = ((2300.0, 10),)
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 500
    max_reward: float = 3000.0
    deg: int = 1


@dataclass
class CDTHopperVelocityConfig(CDTTrainConfig):
    """Overrides for the OfflineHopperVelocityGymnasium-v1 task."""

    # task / hardware
    task: str = "OfflineHopperVelocityGymnasium-v1"
    device: str = "cuda:5"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (1750.0, 20),
        (1750.0, 40),
        (1750.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 300
    max_reward: float = 2000.0
    deg: int = 1
    # pretrained state/action encoder checkpoints
    state_encoder_path: str = "logs/OfflineHopperVelocityGymnasium-v1-cost-10/sa_encoder-635d/sa_encoder-635d_state_AE"
    action_encoder_path: str = "logs/OfflineHopperVelocityGymnasium-v1-cost-10/sa_encoder-635d/sa_encoder-635d_action_AE"

@dataclass
class CDTHopperVelocityV0Config(CDTTrainConfig):
    """Overrides for the OfflineHopperVelocityGymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineHopperVelocityGymnasium-v0"
    device: str = "cuda:5"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (1750.0, 20),
        (1750.0, 40),
        (1750.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 300
    max_reward: float = 2000.0
    deg: int = 1

@dataclass
class CDTHopperVelocityV2Config(CDTTrainConfig):
    """OfflineHopperVelocityGymnasium-v2, reading data from a local .pkl file."""

    # task / data / hardware
    task: str = "OfflineHopperVelocityGymnasium-v2"
    data_path: str = "finetune_data/OfflineHopperVelocityGymnasium-v2.pkl"
    device: str = "cuda:1"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # single evaluation target as (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = ((1600.0, 10),)
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 300
    max_reward: float = 2000.0
    deg: int = 1


@dataclass
class CDTSwimmerVelocityConfig(CDTTrainConfig):
    """Overrides for the OfflineSwimmerVelocityGymnasium-v1 task."""

    # task / hardware
    task: str = "OfflineSwimmerVelocityGymnasium-v1"
    device: str = "cuda:2"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (160.0, 20),
        (160.0, 40),
        (160.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 50
    max_reward: float = 250.0
    deg: int = 1

@dataclass
class CDTSwimmerVelocityV0Config(CDTTrainConfig):
    """Overrides for the OfflineSwimmerVelocityGymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineSwimmerVelocityGymnasium-v0"
    device: str = "cuda:2"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (160.0, 20),
        (160.0, 40),
        (160.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 50
    max_reward: float = 250.0
    deg: int = 1

@dataclass
class CDTSwimmerVelocityV2Config(CDTTrainConfig):
    """OfflineSwimmerVelocityGymnasium-v2, reading data from a local .pkl file."""

    # task / data / hardware
    task: str = "OfflineSwimmerVelocityGymnasium-v2"
    data_path: str = "finetune_data/OfflineSwimmerVelocityGymnasium-v2.pkl"
    device: str = "cuda:3"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # single evaluation target as (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = ((120.0, 10),)
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 50
    max_reward: float = 250.0
    deg: int = 1

@dataclass
class CDTWalker2dVelocityConfig(CDTTrainConfig):
    """Overrides for the OfflineWalker2dVelocityGymnasium-v1 task."""

    # task / hardware
    task: str = "OfflineWalker2dVelocityGymnasium-v1"
    device: str = "cuda:2"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (2800.0, 20),
        (2800.0, 40),
        (2800.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 800
    max_reward: float = 3600.0
    deg: int = 1

@dataclass
class CDTWalker2dVelocityV0Config(CDTTrainConfig):
    """Overrides for the OfflineWalker2dVelocityGymnasium-v0 task."""

    # task / hardware
    task: str = "OfflineWalker2dVelocityGymnasium-v0"
    device: str = "cuda:2"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # evaluation targets, each entry is (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (2800.0, 20),
        (2800.0, 40),
        (2800.0, 80),
    )
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 800
    max_reward: float = 3600.0
    deg: int = 1

@dataclass
class CDTWalker2dVelocityV2Config(CDTTrainConfig):
    """OfflineWalker2dVelocityGymnasium-v2, reading data from a local .pkl file."""

    # task / data / hardware
    task: str = "OfflineWalker2dVelocityGymnasium-v2"
    data_path: str = "finetune_data/OfflineWalker2dVelocityGymnasium-v2.pkl"
    device: str = "cuda:4"
    # sequence model
    episode_len: int = 1000
    seq_len: int = 20
    # single evaluation target as (reward, cost)
    target_returns: Tuple[Tuple[float, ...], ...] = ((2200.0, 10),)
    # augmentation
    min_reward: float = 1
    max_rew_decrease: float = 800
    max_reward: float = 3600.0
    deg: int = 1


@dataclass
class CDTEasySparseConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-easysparse-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-easysparse-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation params
    deg: int = 2
    max_reward: float = 500.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:3"


@dataclass
class CDTEasyMeanConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-easymean-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-easymean-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation params
    deg: int = 2
    max_reward: float = 500.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:3"


@dataclass
class CDTEasyDenseConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-easydense-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-easydense-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation params
    deg: int = 2
    max_reward: float = 500.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:2"


@dataclass
class CDTMediumSparseConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-mediumsparse-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-mediumsparse-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (300.0, 20),
        (300.0, 40),
    )
    # augmentation params
    deg: int = 0
    max_reward: float = 300.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:3"


@dataclass
class CDTMediumMeanConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-mediummean-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-mediummean-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (300.0, 20),
        (300.0, 40),
    )
    # augmentation params
    deg: int = 0
    max_reward: float = 300.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:2"


@dataclass
class CDTMediumDenseConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-mediumdense-v0``.

    The model, evaluation-target and augmentation values below mirror the
    sibling medium-tier MetaDrive configs (mediumsparse / mediummean).
    Without them this task would fall back to whatever ``CDTTrainConfig``
    defines for these fields. That base config's default task is a velocity
    task, so its values are presumably not tuned for MetaDrive.
    NOTE(review): confirm these match the intended experiment settings.
    ``device`` is intentionally left to the ``CDTTrainConfig`` default rather
    than pinning a GPU index.
    """

    # model params
    seq_len: int = 20
    episode_len: int = 1000
    # training params
    task: str = "OfflineMetadrive-mediumdense-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...],
                          ...] = ((300.0, 10), (300.0, 20), (300.0, 40))
    # augmentation params (same as the other medium-tier MetaDrive configs)
    deg: int = 0
    max_reward: float = 300.0
    max_rew_decrease: float = 100
    min_reward: float = 1


@dataclass
class CDTHardSparseConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-hardsparse-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-hardsparse-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation params
    deg: int = 1
    max_reward: float = 500.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:2"


@dataclass
class CDTHardMeanConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-hardmean-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-hardmean-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation params
    deg: int = 1
    max_reward: float = 500.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:2"


@dataclass
class CDTHardDenseConfig(CDTTrainConfig):
    """CDT training defaults for ``OfflineMetadrive-harddense-v0``."""

    # model params
    seq_len: int = 20
    episode_len: int = 1_000
    # training params
    task: str = "OfflineMetadrive-harddense-v0"
    update_steps: int = 200_000
    # evaluation targets; each entry appears to be a (return, cost) pair
    target_returns: Tuple[Tuple[float, ...], ...] = (
        (300.0, 10),
        (350.0, 20),
        (400.0, 40),
    )
    # augmentation params
    deg: int = 1
    max_reward: float = 500.0
    max_rew_decrease: float = 100
    min_reward: float = 1
    # pinned GPU index; override on machines with fewer GPUs
    device: str = "cuda:2"


# Registry: offline task id -> default CDT config class for that task.
# Entry order is kept stable so any code iterating this dict sees the same order.
CDT_DEFAULT_CONFIG = {
    # ---- bullet_safety_gym ----
    "OfflineCarCircle-v0": CDTCarCircleConfig,
    "OfflineAntRun-v0": CDTAntRunConfig,
    "OfflineDroneRun-v0": CDTDroneRunConfig,
    "OfflineDroneCircle-v0": CDTDroneCircleConfig,
    "OfflineCarRun-v0": CDTCarRunConfig,
    "OfflineAntCircle-v0": CDTAntCircleConfig,
    "OfflineBallCircle-v0": CDTBallCircleConfig,
    "OfflineBallRun-v0": CDTBallRunConfig,
    # ---- bullet_safety_gym: "Org" variants ----
    "OfflineCarCircleOrg-v0": CDTCarCircleOrgConfig,
    "OfflineAntRunOrg-v0": CDTAntRunOrgConfig,
    # ---- safety_gymnasium: car ----
    "OfflineCarButton1Gymnasium-v0": CDTCarButton1Config,
    "OfflineCarButton2Gymnasium-v0": CDTCarButton2Config,
    "OfflineCarCircle1Gymnasium-v0": CDTCarCircle1Config,
    "OfflineCarCircle2Gymnasium-v0": CDTCarCircle2Config,
    "OfflineCarGoal1Gymnasium-v0": CDTCarGoal1Config,
    "OfflineCarGoal2Gymnasium-v0": CDTCarGoal2Config,
    "OfflineCarPush1Gymnasium-v0": CDTCarPush1Config,
    "OfflineCarPush2Gymnasium-v0": CDTCarPush2Config,
    # ---- safety_gymnasium: point ----
    "OfflinePointButton1Gymnasium-v0": CDTPointButton1Config,
    "OfflinePointButton2Gymnasium-v0": CDTPointButton2Config,
    "OfflinePointCircle1Gymnasium-v0": CDTPointCircle1Config,
    "OfflinePointCircle2Gymnasium-v0": CDTPointCircle2Config,
    "OfflinePointGoal1Gymnasium-v0": CDTPointGoal1Config,
    "OfflinePointGoal2Gymnasium-v0": CDTPointGoal2Config,
    "OfflinePointPush1Gymnasium-v0": CDTPointPush1Config,
    "OfflinePointPush2Gymnasium-v0": CDTPointPush2Config,
    "OfflinePointGoal1WoRewGymnasium-v0": CDTPointGoal1WoRewConfig,
    "OfflinePointGoal1LargeEmbGymnasium-v0": CDTPointGoal1LargeEmbConfig,
    "OfflinePointGoal2WoRewGymnasium-v0": CDTPointGoal2WoRewConfig,
    "OfflinePointGoal2LargeEmbGymnasium-v0": CDTPointGoal2LargeEmbConfig,
    "OfflinePointGoal2SmallSeqGymnasium-v0": CDTPointGoal2SmallSeqConfig,
    "OfflinePointGoal230seqGymnasium-v0": CDTPointGoal230seqConfig,
    "OfflinePointButton1OrgGymnasium-v0": CDTPointButton1OrgConfig,
    "OfflinePointButton2OrgGymnasium-v0": CDTPointButton2OrgConfig,
    "OfflinePointCircle1OrgGymnasium-v0": CDTPointCircle1OrgConfig,
    "OfflinePointCircle2OrgGymnasium-v0": CDTPointCircle2OrgConfig,
    "OfflinePointGoal1OrgGymnasium-v0": CDTPointGoal1OrgConfig,
    "OfflinePointGoal2OrgGymnasium-v0": CDTPointGoal2OrgConfig,
    "OfflinePointPush1OrgGymnasium-v0": CDTPointPush1OrgConfig,
    "OfflinePointPush2OrgGymnasium-v0": CDTPointPush2OrgConfig,
    # ---- safety_gymnasium: velocity ----
    "OfflineAntVelocityGymnasium-v1": CDTAntVelocityConfig,
    "OfflineAntVelocityGymnasium-v0": CDTAntVelocityV0Config,
    "OfflineHalfCheetahVelocityGymnasium-v1": CDTHalfCheetahVelocityConfig,
    "OfflineHalfCheetahVelocityGymnasium-v0": CDTHalfCheetahVelocityV0Config,
    "OfflineHopperVelocityGymnasium-v1": CDTHopperVelocityConfig,
    "OfflineSwimmerVelocityGymnasium-v1": CDTSwimmerVelocityConfig,
    "OfflineWalker2dVelocityGymnasium-v1": CDTWalker2dVelocityConfig,
    "OfflineHopperVelocityGymnasium-v0": CDTHopperVelocityV0Config,
    "OfflineSwimmerVelocityGymnasium-v0": CDTSwimmerVelocityV0Config,
    "OfflineWalker2dVelocityGymnasium-v0": CDTWalker2dVelocityV0Config,
    "OfflineAntVelocityGymnasium-v2": CDTAntVelocityV2Config,
    "OfflineHalfCheetahVelocityGymnasium-v2": CDTHalfCheetahVelocityV2Config,
    "OfflineHopperVelocityGymnasium-v2": CDTHopperVelocityV2Config,
    "OfflineSwimmerVelocityGymnasium-v2": CDTSwimmerVelocityV2Config,
    "OfflineWalker2dVelocityGymnasium-v2": CDTWalker2dVelocityV2Config,
    # ---- safe_metadrive ----
    "OfflineMetadrive-easysparse-v0": CDTEasySparseConfig,
    "OfflineMetadrive-easymean-v0": CDTEasyMeanConfig,
    "OfflineMetadrive-easydense-v0": CDTEasyDenseConfig,
    "OfflineMetadrive-mediumsparse-v0": CDTMediumSparseConfig,
    "OfflineMetadrive-mediummean-v0": CDTMediumMeanConfig,
    "OfflineMetadrive-mediumdense-v0": CDTMediumDenseConfig,
    "OfflineMetadrive-hardsparse-v0": CDTHardSparseConfig,
    "OfflineMetadrive-hardmean-v0": CDTHardMeanConfig,
    "OfflineMetadrive-harddense-v0": CDTHardDenseConfig,
}
